from keras.models import Sequential
from keras.layers.convolutional import Convolution2D, MaxPooling2D, ZeroPadding2D
from keras.layers.core import Dense, Dropout, Activation, Flatten
import h5py
from keras.layers.normalization import BatchNormalization


def set_VGG16_weights(model,
                      weights_path='/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/ExData/vgg16.h5'):
    """Copy VGG16 weights from an HDF5 file into ``model``, matching layers by position.

    The file is expected to contain one group per layer named ``layer_<i>``,
    each with an ``nb_params`` attribute and datasets ``param_0`` ..
    ``param_<nb_params-1>``. Model layers are walked in order and each one
    consumes the next ``layer_<i>`` group. Layers whose name starts with
    ``dr_`` are skipped and do NOT consume a group (presumably extra layers
    that are absent from the file — confirm with the file's producer).

    Args:
        model: a built Keras model whose ``layers`` line up with the file's
            ``layer_<i>`` groups.
        weights_path: path to the HDF5 weights file. Defaults to the path
            this module has always used.

    Raises:
        KeyError: if the model has more (non-``dr_``) layers than the file
            has ``layer_<i>`` groups.
    """
    # Open read-only: h5py < 3.0 defaults to read/write and would silently
    # create an empty file if weights_path does not exist.
    with h5py.File(weights_path, 'r') as f:
        group_idx = 0  # index of the next layer_<i> group to consume
        for layer in model.layers:
            if layer.name.startswith('dr_'):
                continue
            g = f['layer_{}'.format(group_idx)]
            weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
            layer.set_weights(weights)
            group_idx += 1


def set_VGG16_finetuned_weights(model,
                                weights_path='/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/experiments/224x224/baseline/'
                                             'M-VGG_16_E-3_BS-128_CR-False_CC-True_Aug-False_ConTi-False_L-categorical_crossentropy_CF-softmax.h5'):
    """Copy fine-tuned VGG16 weights from an HDF5 file into ``model``, matching layers by name.

    For every layer of ``model`` that owns weights, the group ``f[layer.name]``
    is read. It is expected to carry an ``nb_params`` attribute and datasets
    ``param_0`` .. ``param_<nb_params-1>``, which are passed to
    ``layer.set_weights``.

    Layers without weights (padding, pooling, dropout, flatten) are skipped.
    They have nothing to load, and their auto-generated names depend on
    Keras' per-process layer counters, so looking them up by name is fragile.

    Args:
        model: a built Keras model whose weighted layers are named like the
            groups in the file.
        weights_path: path to the HDF5 weights file. Defaults to the path
            this module has always used.

    Raises:
        KeyError: if a weighted layer's name has no matching group in the file.
    """
    # Open read-only: h5py < 3.0 defaults to read/write and would silently
    # create an empty file if weights_path does not exist.
    with h5py.File(weights_path, 'r') as f:
        for layer in model.layers:
            if not layer.get_weights():
                continue
            g = f[layer.name]
            weights = [g['param_{}'.format(p)] for p in range(g.attrs['nb_params'])]
            layer.set_weights(weights)

def VGG16_finetuned_weights(weights_path='/media/dell/cb552bf1-c649-4cca-8aca-3c24afca817b/dell/wxm/Data/KaggleDDD/experiments/224x224/baseline/'
                                         'M-VGG_16_E-3_BS-128_CR-False_CC-True_Aug-False_ConTi-False_L-categorical_crossentropy_CF-softmax.h5'):
    """Open the fine-tuned VGG16 weights file read-only and return the ``h5py.File``.

    The caller owns the returned handle and must ``close()`` it, or use it as
    a context manager (``with VGG16_finetuned_weights() as f: ...``).

    Note that the returned object is an HDF5 file, not a list of arrays, so
    it cannot be passed directly to ``Model.set_weights``.

    Args:
        weights_path: path to the HDF5 weights file. Defaults to the path
            this module has always used.
    """
    # Read-only: h5py < 3.0 defaults to read/write and would create an empty
    # file if weights_path does not exist.
    return h5py.File(weights_path, 'r')

def VGG_16_Extractor(which_feats=None):
    """Build a frozen VGG16 feature extractor truncated at ``which_feats``.

    Weights are loaded with ``set_VGG16_weights``. The input shape is
    ``(3, 224, 224)``, as passed to the first ``ZeroPadding2D``. Every
    convolution and dense layer is created with ``trainable=False``.

    Args:
        which_feats: where to cut the network:
            ``'conv_1'`` .. ``'conv_5'`` -> output of the N-th max-pooling layer;
            ``'FC7'`` -> output of the second 4096-unit dense layer (``dense_2``).

    Returns:
        The truncated ``Sequential`` model with weights loaded.

    Raises:
        ValueError: if ``which_feats`` is not one of the values above. The
            default ``None`` is rejected; previously it silently returned None.
    """
    valid_feats = ('conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5', 'FC7')
    if which_feats not in valid_feats:
        raise ValueError('which_feats must be one of {}; got {!r}'.format(valid_feats, which_feats))

    trainable_conv = False
    trainable_fc = False
    # (number of filters, number of 3x3 convolutions) for VGG16 blocks 1..5.
    conv_blocks = [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]

    model = Sequential()
    for block_idx, (nb_filter, nb_convs) in enumerate(conv_blocks, start=1):
        for conv_idx in range(1, nb_convs + 1):
            if block_idx == 1 and conv_idx == 1:
                # Only the very first layer declares the input shape.
                model.add(ZeroPadding2D((1, 1), input_shape=(3, 224, 224)))
            else:
                model.add(ZeroPadding2D((1, 1)))
            model.add(Convolution2D(nb_filter, 3, 3, activation='relu', trainable=trainable_conv,
                                    name='conv{}_{}'.format(block_idx, conv_idx)))
        model.add(MaxPooling2D((2, 2), strides=(2, 2)))
        if which_feats == 'conv_{}'.format(block_idx):
            set_VGG16_weights(model)
            return model

    # which_feats == 'FC7'
    model.add(Flatten())
    model.add(Dense(4096, activation='relu', trainable=trainable_fc, name="dense_1"))
    model.add(Dropout(0.5))
    model.add(Dense(4096, activation='relu', trainable=trainable_fc, name="dense_2"))
    set_VGG16_weights(model)
    return model


def VGG_16_finetuned_Extractor(which_feats=None):
    """Build a frozen, fine-tuned VGG16 feature extractor truncated at ``which_feats``.

    Each conv block is followed by a named ``Dropout`` (``d_1`` .. ``d_5``),
    and the dense head is 2048 -> 1024 units. Weights are loaded by layer name
    with ``set_VGG16_finetuned_weights``. The input shape is ``(3, 224, 224)``.
    Every convolution and dense layer is created with ``trainable=False``.

    Args:
        which_feats: where to cut the network:
            ``'conv_1'`` .. ``'conv_5'`` -> output of the N-th block's dropout layer;
            ``'FC7'`` -> output of the 1024-unit dense layer (``dense_2``).

    Returns:
        The truncated ``Sequential`` model with weights loaded.

    Raises:
        ValueError: if ``which_feats`` is not one of the values above. The
            default ``None`` is rejected; previously it silently returned None.
    """
    valid_feats = ('conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5', 'FC7')
    if which_feats not in valid_feats:
        raise ValueError('which_feats must be one of {}; got {!r}'.format(valid_feats, which_feats))

    trainable_conv = False
    trainable_fc = False
    # (number of filters, number of 3x3 convolutions, dropout rate) for blocks 1..5.
    conv_blocks = [(64, 2, 0.25), (128, 2, 0.25), (256, 3, 0.25), (512, 3, 0.5), (512, 3, 0.5)]

    model = Sequential()
    for block_idx, (nb_filter, nb_convs, drop_rate) in enumerate(conv_blocks, start=1):
        for conv_idx in range(1, nb_convs + 1):
            if block_idx == 1 and conv_idx == 1:
                # Only the very first layer declares the input shape.
                model.add(ZeroPadding2D((1, 1), input_shape=(3, 224, 224)))
            else:
                model.add(ZeroPadding2D((1, 1)))
            model.add(Convolution2D(nb_filter, 3, 3, activation='relu', trainable=trainable_conv,
                                    name='conv{}_{}'.format(block_idx, conv_idx)))
        model.add(MaxPooling2D((2, 2), strides=(2, 2)))
        model.add(Dropout(drop_rate, name='d_{}'.format(block_idx)))
        if which_feats == 'conv_{}'.format(block_idx):
            set_VGG16_finetuned_weights(model)
            return model

    # which_feats == 'FC7'
    model.add(Flatten(name='flatten_1'))
    model.add(Dense(2048, activation='relu', trainable=trainable_fc, name="dense_1"))
    model.add(Dropout(0.5))
    model.add(Dense(1024, activation='relu', trainable=trainable_fc, name="dense_2"))
    # Load by layer name like the conv branches. The previous
    # model.set_weights(<h5py.File>) passed a file where a list of arrays is
    # required, and it leaked the open handle.
    set_VGG16_finetuned_weights(model)
    return model
